{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Intel® Extension for Scikit-learn RandomForestClassifier for rain in Australia dataset\n",
    "The goal is to predict whether it will rain the next day."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "from timeit import default_timer as timer\n",
    "\n",
    "import pandas as pd\n",
    "from IPython.display import HTML\n",
    "from sklearn.datasets import fetch_openml\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import LabelEncoder, OrdinalEncoder, StandardScaler\n",
    "\n",
    "# Suppress every Python warning raised from here on (no category or module filter is given)\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Date</th>\n",
       "      <th>Location</th>\n",
       "      <th>MinTemp</th>\n",
       "      <th>MaxTemp</th>\n",
       "      <th>Rainfall</th>\n",
       "      <th>Evaporation</th>\n",
       "      <th>Sunshine</th>\n",
       "      <th>WindGustDir</th>\n",
       "      <th>WindGustSpeed</th>\n",
       "      <th>WindDir9am</th>\n",
       "      <th>...</th>\n",
       "      <th>Humidity9am</th>\n",
       "      <th>Humidity3pm</th>\n",
       "      <th>Pressure9am</th>\n",
       "      <th>Pressure3pm</th>\n",
       "      <th>Cloud9am</th>\n",
       "      <th>Cloud3pm</th>\n",
       "      <th>Temp9am</th>\n",
       "      <th>Temp3pm</th>\n",
       "      <th>RainToday</th>\n",
       "      <th>RainTomorrow</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2008-12-01</td>\n",
       "      <td>Albury</td>\n",
       "      <td>13.4</td>\n",
       "      <td>22.9</td>\n",
       "      <td>0.6</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>W</td>\n",
       "      <td>44.0</td>\n",
       "      <td>W</td>\n",
       "      <td>...</td>\n",
       "      <td>71.0</td>\n",
       "      <td>22.0</td>\n",
       "      <td>1007.7</td>\n",
       "      <td>1007.1</td>\n",
       "      <td>8.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>16.9</td>\n",
       "      <td>21.8</td>\n",
       "      <td>No</td>\n",
       "      <td>No</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2008-12-02</td>\n",
       "      <td>Albury</td>\n",
       "      <td>7.4</td>\n",
       "      <td>25.1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>WNW</td>\n",
       "      <td>44.0</td>\n",
       "      <td>NNW</td>\n",
       "      <td>...</td>\n",
       "      <td>44.0</td>\n",
       "      <td>25.0</td>\n",
       "      <td>1010.6</td>\n",
       "      <td>1007.8</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>17.2</td>\n",
       "      <td>24.3</td>\n",
       "      <td>No</td>\n",
       "      <td>No</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2008-12-03</td>\n",
       "      <td>Albury</td>\n",
       "      <td>12.9</td>\n",
       "      <td>25.7</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>WSW</td>\n",
       "      <td>46.0</td>\n",
       "      <td>W</td>\n",
       "      <td>...</td>\n",
       "      <td>38.0</td>\n",
       "      <td>30.0</td>\n",
       "      <td>1007.6</td>\n",
       "      <td>1008.7</td>\n",
       "      <td>NaN</td>\n",
       "      <td>2.0</td>\n",
       "      <td>21.0</td>\n",
       "      <td>23.2</td>\n",
       "      <td>No</td>\n",
       "      <td>No</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2008-12-04</td>\n",
       "      <td>Albury</td>\n",
       "      <td>9.2</td>\n",
       "      <td>28.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NE</td>\n",
       "      <td>24.0</td>\n",
       "      <td>SE</td>\n",
       "      <td>...</td>\n",
       "      <td>45.0</td>\n",
       "      <td>16.0</td>\n",
       "      <td>1017.6</td>\n",
       "      <td>1012.8</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>18.1</td>\n",
       "      <td>26.5</td>\n",
       "      <td>No</td>\n",
       "      <td>No</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2008-12-05</td>\n",
       "      <td>Albury</td>\n",
       "      <td>17.5</td>\n",
       "      <td>32.3</td>\n",
       "      <td>1.0</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>W</td>\n",
       "      <td>41.0</td>\n",
       "      <td>ENE</td>\n",
       "      <td>...</td>\n",
       "      <td>82.0</td>\n",
       "      <td>33.0</td>\n",
       "      <td>1010.8</td>\n",
       "      <td>1006.0</td>\n",
       "      <td>7.0</td>\n",
       "      <td>8.0</td>\n",
       "      <td>17.8</td>\n",
       "      <td>29.7</td>\n",
       "      <td>No</td>\n",
       "      <td>No</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 23 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         Date Location  MinTemp  MaxTemp  Rainfall  Evaporation  Sunshine  \\\n",
       "0  2008-12-01   Albury     13.4     22.9       0.6          NaN       NaN   \n",
       "1  2008-12-02   Albury      7.4     25.1       0.0          NaN       NaN   \n",
       "2  2008-12-03   Albury     12.9     25.7       0.0          NaN       NaN   \n",
       "3  2008-12-04   Albury      9.2     28.0       0.0          NaN       NaN   \n",
       "4  2008-12-05   Albury     17.5     32.3       1.0          NaN       NaN   \n",
       "\n",
       "  WindGustDir  WindGustSpeed WindDir9am  ... Humidity9am  Humidity3pm  \\\n",
       "0           W           44.0          W  ...        71.0         22.0   \n",
       "1         WNW           44.0        NNW  ...        44.0         25.0   \n",
       "2         WSW           46.0          W  ...        38.0         30.0   \n",
       "3          NE           24.0         SE  ...        45.0         16.0   \n",
       "4           W           41.0        ENE  ...        82.0         33.0   \n",
       "\n",
       "   Pressure9am  Pressure3pm  Cloud9am  Cloud3pm  Temp9am  Temp3pm  RainToday  \\\n",
       "0       1007.7       1007.1       8.0       NaN     16.9     21.8         No   \n",
       "1       1010.6       1007.8       NaN       NaN     17.2     24.3         No   \n",
       "2       1007.6       1008.7       NaN       2.0     21.0     23.2         No   \n",
       "3       1017.6       1012.8       NaN       NaN     18.1     26.5         No   \n",
       "4       1010.8       1006.0       7.0       8.0     17.8     29.7         No   \n",
       "\n",
       "   RainTomorrow  \n",
       "0            No  \n",
       "1            No  \n",
       "2            No  \n",
       "3            No  \n",
       "4            No  \n",
       "\n",
       "[5 rows x 23 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fetch OpenML dataset 46315 with pandas output (as_frame=True) and take its frame\n",
    "data = fetch_openml(data_id=46315, as_frame=True)\n",
    "df = data.frame\n",
    "# Preview the first rows\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Explore the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(145460, 23)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the dimensions of the dataset as (rows, columns)\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 145460 entries, 0 to 145459\n",
      "Data columns (total 23 columns):\n",
      " #   Column         Non-Null Count   Dtype  \n",
      "---  ------         --------------   -----  \n",
      " 0   Date           145460 non-null  object \n",
      " 1   Location       145460 non-null  object \n",
      " 2   MinTemp        143975 non-null  float64\n",
      " 3   MaxTemp        144199 non-null  float64\n",
      " 4   Rainfall       142199 non-null  float64\n",
      " 5   Evaporation    82670 non-null   float64\n",
      " 6   Sunshine       75625 non-null   float64\n",
      " 7   WindGustDir    135134 non-null  object \n",
      " 8   WindGustSpeed  135197 non-null  float64\n",
      " 9   WindDir9am     134894 non-null  object \n",
      " 10  WindDir3pm     141232 non-null  object \n",
      " 11  WindSpeed9am   143693 non-null  float64\n",
      " 12  WindSpeed3pm   142398 non-null  float64\n",
      " 13  Humidity9am    142806 non-null  float64\n",
      " 14  Humidity3pm    140953 non-null  float64\n",
      " 15  Pressure9am    130395 non-null  float64\n",
      " 16  Pressure3pm    130432 non-null  float64\n",
      " 17  Cloud9am       89572 non-null   float64\n",
      " 18  Cloud3pm       86102 non-null   float64\n",
      " 19  Temp9am        143693 non-null  float64\n",
      " 20  Temp3pm        141851 non-null  float64\n",
      " 21  RainToday      142199 non-null  object \n",
      " 22  RainTomorrow   142193 non-null  object \n",
      "dtypes: float64(16), object(7)\n",
      "memory usage: 25.5+ MB\n"
     ]
    }
   ],
   "source": [
    "# Show the summary of the dataset: per-column non-null counts and dtypes\n",
    "df.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Missing Values</th>\n",
       "      <th>Percentage (%)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Date</th>\n",
       "      <td>0</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Location</th>\n",
       "      <td>0</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MinTemp</th>\n",
       "      <td>1485</td>\n",
       "      <td>1.020899</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MaxTemp</th>\n",
       "      <td>1261</td>\n",
       "      <td>0.866905</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Rainfall</th>\n",
       "      <td>3261</td>\n",
       "      <td>2.241853</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Evaporation</th>\n",
       "      <td>62790</td>\n",
       "      <td>43.166506</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Sunshine</th>\n",
       "      <td>69835</td>\n",
       "      <td>48.009762</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>WindGustDir</th>\n",
       "      <td>10326</td>\n",
       "      <td>7.098859</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>WindGustSpeed</th>\n",
       "      <td>10263</td>\n",
       "      <td>7.055548</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>WindDir9am</th>\n",
       "      <td>10566</td>\n",
       "      <td>7.263853</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>WindDir3pm</th>\n",
       "      <td>4228</td>\n",
       "      <td>2.906641</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>WindSpeed9am</th>\n",
       "      <td>1767</td>\n",
       "      <td>1.214767</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>WindSpeed3pm</th>\n",
       "      <td>3062</td>\n",
       "      <td>2.105046</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Humidity9am</th>\n",
       "      <td>2654</td>\n",
       "      <td>1.824557</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Humidity3pm</th>\n",
       "      <td>4507</td>\n",
       "      <td>3.098446</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Pressure9am</th>\n",
       "      <td>15065</td>\n",
       "      <td>10.356799</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Pressure3pm</th>\n",
       "      <td>15028</td>\n",
       "      <td>10.331363</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Cloud9am</th>\n",
       "      <td>55888</td>\n",
       "      <td>38.421559</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Cloud3pm</th>\n",
       "      <td>59358</td>\n",
       "      <td>40.807095</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Temp9am</th>\n",
       "      <td>1767</td>\n",
       "      <td>1.214767</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Temp3pm</th>\n",
       "      <td>3609</td>\n",
       "      <td>2.481094</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RainToday</th>\n",
       "      <td>3261</td>\n",
       "      <td>2.241853</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RainTomorrow</th>\n",
       "      <td>3267</td>\n",
       "      <td>2.245978</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "               Missing Values  Percentage (%)\n",
       "Date                        0        0.000000\n",
       "Location                    0        0.000000\n",
       "MinTemp                  1485        1.020899\n",
       "MaxTemp                  1261        0.866905\n",
       "Rainfall                 3261        2.241853\n",
       "Evaporation             62790       43.166506\n",
       "Sunshine                69835       48.009762\n",
       "WindGustDir             10326        7.098859\n",
       "WindGustSpeed           10263        7.055548\n",
       "WindDir9am              10566        7.263853\n",
       "WindDir3pm               4228        2.906641\n",
       "WindSpeed9am             1767        1.214767\n",
       "WindSpeed3pm             3062        2.105046\n",
       "Humidity9am              2654        1.824557\n",
       "Humidity3pm              4507        3.098446\n",
       "Pressure9am             15065       10.356799\n",
       "Pressure3pm             15028       10.331363\n",
       "Cloud9am                55888       38.421559\n",
       "Cloud3pm                59358       40.807095\n",
       "Temp9am                  1767        1.214767\n",
       "Temp3pm                  3609        2.481094\n",
       "RainToday                3261        2.241853\n",
       "RainTomorrow             3267        2.245978"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Count the missing values per column and express them as a percentage of all rows\n",
    "total_rows = len(df)\n",
    "missing_values = df.isna().sum()\n",
    "missing_values_percentage = missing_values.div(total_rows).mul(100)\n",
    "missing_values_df = pd.concat(\n",
    "    [missing_values, missing_values_percentage],\n",
    "    axis=1,\n",
    "    keys=['Missing Values', 'Percentage (%)'],\n",
    ")\n",
    "missing_values_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(145460, 19)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep only the columns that are at least 70% populated,\n",
    "# i.e. drop columns with more than 30% missing values\n",
    "min_non_null = df.shape[0] * 0.7\n",
    "df = df.dropna(axis='columns', thresh=min_non_null)\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(142193, 19)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Drop rows with missing target value ('RainTomorrow'); rows that only miss feature values are kept\n",
    "df = df.dropna(subset=['RainTomorrow'])\n",
    "df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the target variable as integers: 'No' -> 0, 'Yes' -> 1\n",
    "# NOTE: run this cell only once per kernel session; mapping the already-encoded 0/1 values\n",
    "# a second time would turn them into NaN because they are not keys of the mapping\n",
    "df['RainTomorrow'] = df['RainTomorrow'].map({'No': 0, 'Yes': 1})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "Index: 142193 entries, 0 to 145458\n",
      "Data columns (total 22 columns):\n",
      " #   Column         Non-Null Count   Dtype         \n",
      "---  ------         --------------   -----         \n",
      " 0   Date           142193 non-null  datetime64[ns]\n",
      " 1   Location       142193 non-null  object        \n",
      " 2   MinTemp        141556 non-null  float64       \n",
      " 3   MaxTemp        141871 non-null  float64       \n",
      " 4   Rainfall       140787 non-null  float64       \n",
      " 5   WindGustDir    132863 non-null  object        \n",
      " 6   WindGustSpeed  132923 non-null  float64       \n",
      " 7   WindDir9am     132180 non-null  object        \n",
      " 8   WindDir3pm     138415 non-null  object        \n",
      " 9   WindSpeed9am   140845 non-null  float64       \n",
      " 10  WindSpeed3pm   139563 non-null  float64       \n",
      " 11  Humidity9am    140419 non-null  float64       \n",
      " 12  Humidity3pm    138583 non-null  float64       \n",
      " 13  Pressure9am    128179 non-null  float64       \n",
      " 14  Pressure3pm    128212 non-null  float64       \n",
      " 15  Temp9am        141289 non-null  float64       \n",
      " 16  Temp3pm        139467 non-null  float64       \n",
      " 17  RainToday      140787 non-null  object        \n",
      " 18  RainTomorrow   142193 non-null  int64         \n",
      " 19  Year           142193 non-null  int64         \n",
      " 20  Month          142193 non-null  int64         \n",
      " 21  Day            142193 non-null  int64         \n",
      "dtypes: datetime64[ns](1), float64(12), int64(4), object(5)\n",
      "memory usage: 25.0+ MB\n"
     ]
    }
   ],
   "source": [
    "# Parse the Date column and derive integer Year, Month and Day columns from it\n",
    "parsed_dates = pd.to_datetime(df['Date'])\n",
    "df['Date'] = parsed_dates\n",
    "df['Year'] = parsed_dates.dt.year.astype('int64')\n",
    "df['Month'] = parsed_dates.dt.month.astype('int64')\n",
    "df['Day'] = parsed_dates.dt.day.astype('int64')\n",
    "\n",
    "df.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Separate the target from the features; the raw Date column is removed from the features as well\n",
    "y = df['RainTomorrow']\n",
    "X = df.drop(['RainTomorrow', 'Date'], axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Numerical Columns: ['MinTemp', 'MaxTemp', 'Rainfall', 'WindGustSpeed', 'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm', 'Pressure9am', 'Pressure3pm', 'Temp9am', 'Temp3pm', 'Year', 'Month', 'Day']\n",
      "Categorical Columns: ['Location', 'WindGustDir', 'WindDir9am', 'WindDir3pm', 'RainToday']\n"
     ]
    }
   ],
   "source": [
    "# Identify the numerical and categorical columns by dtype:\n",
    "# int64/float64 columns are treated as numerical, object columns as categorical\n",
    "num_columns = X.select_dtypes(include=['int64', 'float64']).columns\n",
    "cat_columns = X.select_dtypes(include=['object']).columns\n",
    "\n",
    "print(f'Numerical Columns: {list(num_columns)}')\n",
    "print(f'Categorical Columns: {list(cat_columns)}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess the numerical features\n",
    "# NOTE: the imputer and the scaler are fitted on the full feature matrix, i.e. before the\n",
    "# train/test split performed later in the notebook, so test-set statistics leak into training\n",
    "\n",
    "# Impute missing values with the mean\n",
    "imputer_num = SimpleImputer(strategy='mean')\n",
    "X[num_columns] = imputer_num.fit_transform(X[num_columns])\n",
    "\n",
    "# Scale the numerical columns\n",
    "scaler = StandardScaler()\n",
    "X[num_columns] = scaler.fit_transform(X[num_columns])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess the categorical features\n",
    "\n",
    "# Impute missing values with the most frequent value\n",
    "imputer_cat = SimpleImputer(strategy='most_frequent')\n",
    "X[cat_columns] = imputer_cat.fit_transform(X[cat_columns])\n",
    "\n",
    "# Encode all categorical columns as integer codes in one pass.\n",
    "# OrdinalEncoder is the encoder intended for feature matrices (LabelEncoder is meant for targets)\n",
    "# and it keeps the fitted categories of every column in encoder.categories_,\n",
    "# instead of a single reused encoder that only remembers the last column.\n",
    "encoder = OrdinalEncoder(dtype='int64')\n",
    "X[cat_columns] = encoder.fit_transform(X[cat_columns])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "Index: 142193 entries, 0 to 145458\n",
      "Data columns (total 22 columns):\n",
      " #   Column         Non-Null Count   Dtype         \n",
      "---  ------         --------------   -----         \n",
      " 0   Date           142193 non-null  datetime64[ns]\n",
      " 1   Location       142193 non-null  object        \n",
      " 2   MinTemp        141556 non-null  float64       \n",
      " 3   MaxTemp        141871 non-null  float64       \n",
      " 4   Rainfall       140787 non-null  float64       \n",
      " 5   WindGustDir    132863 non-null  object        \n",
      " 6   WindGustSpeed  132923 non-null  float64       \n",
      " 7   WindDir9am     132180 non-null  object        \n",
      " 8   WindDir3pm     138415 non-null  object        \n",
      " 9   WindSpeed9am   140845 non-null  float64       \n",
      " 10  WindSpeed3pm   139563 non-null  float64       \n",
      " 11  Humidity9am    140419 non-null  float64       \n",
      " 12  Humidity3pm    138583 non-null  float64       \n",
      " 13  Pressure9am    128179 non-null  float64       \n",
      " 14  Pressure3pm    128212 non-null  float64       \n",
      " 15  Temp9am        141289 non-null  float64       \n",
      " 16  Temp3pm        139467 non-null  float64       \n",
      " 17  RainToday      140787 non-null  object        \n",
      " 18  RainTomorrow   142193 non-null  int64         \n",
      " 19  Year           142193 non-null  int64         \n",
      " 20  Month          142193 non-null  int64         \n",
      " 21  Day            142193 non-null  int64         \n",
      "dtypes: datetime64[ns](1), float64(12), int64(4), object(5)\n",
      "memory usage: 25.0+ MB\n"
     ]
    }
   ],
   "source": [
    "# Ensure all feature columns are numerical and contain no missing values.\n",
    "# The check must look at X (the preprocessed features); df still holds the raw, unprocessed columns.\n",
    "assert X.select_dtypes(exclude='number').empty, 'non-numeric feature columns remain'\n",
    "assert not X.isna().any().any(), 'missing values remain in the features'\n",
    "X.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((113754, 20), (28439, 20), (113754,), (28439,))"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Split the dataset into training (80%) and testing (20%) sets; random_state=42 fixes the shuffle\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
    "X_train.shape, X_test.shape, y_train.shape, y_test.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Patch original Scikit-learn with Intel® Extension for Scikit-learn\n",
    "Intel® Extension for Scikit-learn (previously known as daal4py) contains drop-in replacement functionality for the stock Scikit-learn package. You can take advantage of the performance optimizations of Intel® Extension for Scikit-learn by adding just two lines of code before the usual Scikit-learn imports:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Intel(R) Extension for Scikit-learn* enabled (https://github.com/uxlfoundation/scikit-learn-intelex)\n"
     ]
    }
   ],
   "source": [
    "# Enable the Intel(R) Extension for Scikit-learn optimizations.\n",
    "# This must run before the scikit-learn estimator import in the next code cell.\n",
    "from sklearnex import patch_sklearn\n",
    "patch_sklearn()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the RandomForestClassifier with Intel® Extension for Scikit-learn on the Rain in Australia dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Intel® extension for Scikit-learn Training Time: 9.439 seconds\n"
     ]
    }
   ],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "\n",
    "# Hyperparameters shared by the patched and the original run.\n",
    "# random_state is fixed so that re-running the notebook seeds the forest identically\n",
    "# and the reported accuracies are reproducible.\n",
    "params = {\n",
    "    'n_estimators': 1000,\n",
    "    'criterion': 'gini',\n",
    "    'max_features': 'sqrt',\n",
    "    'n_jobs': -1,\n",
    "    'random_state': 42\n",
    "}\n",
    "# Time the construction and fitting of the patched estimator\n",
    "start = timer()\n",
    "patched_model = RandomForestClassifier(**params).fit(X_train, y_train)\n",
    "patched_train_time = timer() - start\n",
    "\n",
    "print(f\"Intel® extension for Scikit-learn Training Time: {patched_train_time:.3f} seconds\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Predict on the test set and compute the accuracy of the RandomForestClassifier trained with Intel® Extension for Scikit-learn."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Intel® extension for Scikit-learn Accuracy: 0.8551\n"
     ]
    }
   ],
   "source": [
    "# Predict on the held-out test set with the patched model and score against the true labels\n",
    "patched_y_pred = patched_model.predict(X_test)\n",
    "patched_accuracy = accuracy_score(y_test, patched_y_pred)\n",
    "\n",
    "print(f\"Intel® extension for Scikit-learn Accuracy: {patched_accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the same algorithm with original Scikit-learn\n",
    "\n",
    "In order to cancel optimizations, we use *unpatch_sklearn* and reimport the class RandomForestClassifier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cancel the optimizations; the estimator class has to be re-imported afterwards to pick up the change\n",
    "from sklearnex import unpatch_sklearn\n",
    "unpatch_sklearn()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Train the RandomForestClassifier with the original Scikit-learn on the Rain in Australia dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original Scikit-learn Training Time: 47.955 seconds\n"
     ]
    }
   ],
   "source": [
    "# Re-import after unpatching so that the name refers to the original scikit-learn class again\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "\n",
    "# Time the construction and fitting with the same hyperparameters (params) as the patched run\n",
    "start = timer()\n",
    "ori_model = RandomForestClassifier(**params).fit(X_train, y_train)\n",
    "ori_train_time = timer() - start\n",
    "\n",
    "print(f\"Original Scikit-learn Training Time: {ori_train_time:.3f} seconds\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Predict on the test set and compute the accuracy of the RandomForestClassifier trained with the original Scikit-learn."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original Scikit-learn Accuracy: 0.8549\n"
     ]
    }
   ],
   "source": [
    "# Predict on the same test set with the original model and score against the true labels\n",
    "ori_y_pred = ori_model.predict(X_test)\n",
    "ori_accuracy = accuracy_score(y_test, ori_y_pred)\n",
    "\n",
    "print(f\"Original Scikit-learn Accuracy: {ori_accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparison\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Original</th>\n",
       "      <th>Patched</th>\n",
       "      <th>Improvement (%)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Accuracy</th>\n",
       "      <td>0.8549</td>\n",
       "      <td>0.8551</td>\n",
       "      <td>0.02</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Training Time (s)</th>\n",
       "      <td>47.9546</td>\n",
       "      <td>9.4395</td>\n",
       "      <td>-80.32</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   Original  Patched  Improvement (%)\n",
       "Accuracy             0.8549   0.8551             0.02\n",
       "Training Time (s)   47.9546   9.4395           -80.32"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Side-by-side summary of both runs\n",
    "compare_df = pd.DataFrame({\n",
    "    'Original': [ori_accuracy, ori_train_time],\n",
    "    'Patched': [patched_accuracy, patched_train_time]\n",
    "}, index=['Accuracy', 'Training Time (s)'])\n",
    "\n",
    "# Relative change of the patched run versus the original, computed from the unrounded values\n",
    "relative_change = (compare_df['Patched'] - compare_df['Original']) / compare_df['Original'] * 100\n",
    "# Accuracy improves when it goes up, training time improves when it goes down:\n",
    "# flip the sign of the time row so that a positive number always means an improvement\n",
    "relative_change.loc['Training Time (s)'] *= -1\n",
    "compare_df['Improvement (%)'] = relative_change.round(2)\n",
    "\n",
    "# Round the raw metrics for display only, after the percentages have been derived\n",
    "compare_df[['Original', 'Patched']] = compare_df[['Original', 'Patched']].round(4)\n",
    "\n",
    "compare_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<h3>Compare Accuracy of patched Scikit-learn and original</h3>Accuracy of patched Scikit-learn: 0.8550933577129998 <br>Accuracy of unpatched Scikit-learn: 0.8548823798305144 <br>Metrics ratio: 1.0002467917077986 <br><h3>With Scikit-learn-intelex patching you can:</h3><ul><li>Use your Scikit-learn code for training and prediction with minimal changes (a couple of lines of code);</li><li>Get comparable model quality</li><li>Get a <strong>5.1x</strong> speedup.</li></ul>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Render the final summary; only the segments that interpolate a value are f-strings\n",
    "HTML(\n",
    "    \"<h3>Compare Accuracy of patched Scikit-learn and original</h3>\"\n",
    "    f\"Accuracy of patched Scikit-learn: {patched_accuracy} <br>\"\n",
    "    f\"Accuracy of unpatched Scikit-learn: {ori_accuracy} <br>\"\n",
    "    f\"Metrics ratio: {patched_accuracy / ori_accuracy} <br>\"\n",
    "    \"<h3>With Scikit-learn-intelex patching you can:</h3>\"\n",
    "    \"<ul>\"\n",
    "    \"<li>Use your Scikit-learn code for training and prediction with minimal changes (a couple of lines of code);</li>\"\n",
    "    \"<li>Get comparable model quality</li>\"\n",
    "    f\"<li>Get a <strong>{ori_train_time / patched_train_time:.1f}x</strong> speedup.</li>\"\n",
    "    \"</ul>\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
